# -*- coding: utf-8 -*-
"""
Created on Thu Mar 30 16:40:57 2023

@author: lv
"""

import torch
import torch.nn as nn
from .SelfAttentionEncoder import SelfAttentionEncoder

class UserIntentAttention(nn.Module):
    """Intent-matching layer built on top of a SelfAttentionEncoder.

    Concatenates the user input ids with a set of intent-template ids,
    runs the combined batch through a self-attention encoder, and either
    returns the raw attention scores or decodes them back to token
    indices via the encoder's embedding.
    """

    def __init__(self, max_seq_length,
                 output_num_sequences,
                 input_max_num_sequences,
                 tokenizer_size,
                 return_scores=False,
                 hidden_dims=None):
        """
        Args:
            max_seq_length: maximum token length of each sequence
                (forwarded to SelfAttentionEncoder).
            output_num_sequences: number of sequences the encoder emits
                (forwarded to SelfAttentionEncoder).
            input_max_num_sequences: maximum number of input sequences
                (forwarded to SelfAttentionEncoder).
            tokenizer_size: vocabulary size; also used as the upper clamp
                bound on decoded indices.
            return_scores: if True, ``forward`` returns the raw attention
                weights instead of decoded indices.
            hidden_dims: hidden layer sizes for the encoder; defaults to
                [1024, 512, 1024]. (Was a mutable default argument —
                replaced with a None sentinel so the list is not shared
                across instances.)
        """
        super().__init__()

        if hidden_dims is None:
            hidden_dims = [1024, 512, 1024]

        self.tokenizer_size = tokenizer_size

        # Inner encoder always returns scores; this wrapper decides
        # whether to decode them (see forward).
        self.self_attention_encoder = SelfAttentionEncoder(
            max_seq_length=max_seq_length,
            output_num_sequences=output_num_sequences,
            input_max_num_sequences=input_max_num_sequences,
            tokenizer_size=tokenizer_size,
            return_scores=True,
            hidden_dims=hidden_dims)

        # Whether forward() should return raw scores or decoded indices.
        self.return_scores = return_scores

    def forward(self, input_ids, intent_template_ids, attention_mask=None):
        """Score the input against the intent templates.

        Args:
            input_ids: user input token ids.
            intent_template_ids: intent-template token ids; concatenated
                with ``input_ids`` along dim 0 (assumes matching trailing
                dimensions — TODO confirm against callers).
            attention_mask: optional mask forwarded to the encoder.

        Returns:
            The encoder's attention weights when ``return_scores`` is
            True, otherwise token indices decoded from those weights.
        """
        # Merge the intent-template ids with the input ids so they are
        # encoded jointly in a single batch.
        combined_ids = torch.cat((input_ids, intent_template_ids), dim=0)

        # Encode the combined sequences; the encoder is configured with
        # return_scores=True, so this yields attention weights.
        attention_weights = self.self_attention_encoder(
            combined_ids, attention_mask=attention_mask)

        # Return raw scores directly when requested.
        if self.return_scores:
            return attention_weights

        # Decode attention weights back to token indices via the
        # encoder's embedding (presumably an argmax-style lookup —
        # verify against SelfAttentionEncoder).
        max_indices = self.self_attention_encoder.embedding.decode(attention_weights)
        # NOTE(review): valid token ids are usually 0..tokenizer_size-1;
        # clamping to tokenizer_size (inclusive) may be off by one —
        # confirm intended behavior before changing.
        max_indices = torch.clamp(max_indices, max=self.tokenizer_size)
        return max_indices


